Prediction assessment of POC attenuation from all plankton metrics

Assess the XGBoost model to predict POC attenuation values from plankton diversity data.

Author

Thelma Panaïotis

Set-up and load data

#|output: false
# Project helpers; presumably attaches the packages and defines the helpers
# (e.g. ggmap(), div_pal) used further down — TODO confirm in utils.R
source("utils.R")
# Restores saved objects into the workspace; presumably provides `res`, the
# nested cross-validation results used throughout this document — TODO confirm
load("data/04.att_from_plankton_all_fit.Rdata")

Rsquares

Black dots on the R² boxplots show the actual values.

# Flatten the nested predictions, keeping the fold and CV type of each row
preds <- unnest(select(res, fold, cv_type, preds), preds)

# One R² per combination of CV type and fold
rsquares <- rsq(group_by(preds, cv_type, fold), truth = att, estimate = .pred)

# Summary statistics of the R² values, one table per CV type
map(split(rsquares, rsquares$cv_type), summary)
$spatial
   cv_type              fold             .metric           .estimator       
 Length:10          Length:10          Length:10          Length:10         
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   .estimate        
 Min.   :0.0000251  
 1st Qu.:0.0010787  
 Median :0.0541479  
 Mean   :0.1487584  
 3rd Qu.:0.2587648  
 Max.   :0.4611127  

$stratified
   cv_type              fold             .metric           .estimator       
 Length:10          Length:10          Length:10          Length:10         
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   .estimate     
 Min.   :0.2577  
 1st Qu.:0.3442  
 Median :0.4239  
 Mean   :0.4032  
 3rd Qu.:0.4332  
 Max.   :0.5079  
# Mean and standard deviation of R² across folds, per CV type
rsquares %>%
  group_by(cv_type) %>%
  summarise(across(.estimate, list(mean = mean, sd = sd), .names = "{.fn}"))
# A tibble: 2 × 3
  cv_type     mean     sd
  <chr>      <dbl>  <dbl>
1 spatial    0.149 0.181 
2 stratified 0.403 0.0804
# R² per fold: one box per CV type, with the individual fold values overlaid
ggplot(rsquares, aes(x = cv_type, y = .estimate)) +
  geom_boxplot(aes(group = cv_type, colour = cv_type)) +
  geom_jitter(size = 0.5, width = 0.1) +
  scale_y_continuous(limits = c(0, 1), expand = c(0, 0)) +
  labs(x = "CV type", y = "R²", colour = "CV type")

Predictions VS truth

Plot predicted vs. true values on the test set of each fold, for each CV type.

# Predicted vs. observed attenuation, one panel per CV type and fold;
# the red line is the 1:1 line
ggplot(preds, aes(x = att, y = .pred)) +
  geom_point(aes(colour = cv_type), size = 0.5) +
  geom_abline(intercept = 0, slope = 1, colour = "red") +
  coord_fixed() +
  facet_wrap(cv_type ~ fold)

Now let’s focus on a representative fold for each CV type.

# Absolute distance of each fold's R² to the median R² of its CV type
dist_to_median <- mutate(
  group_by(rsquares, cv_type),
  diff = abs(.estimate - median(.estimate))
)

# Keep, for each CV type, the fold closest to the median (first one on ties)
repres_fold <- dist_to_median %>%
  filter(diff == min(diff)) %>%
  slice(1)

# Predictions belonging to the representative folds only
repres_preds <- left_join(
  select(repres_fold, cv_type, fold),
  preds,
  by = join_by(cv_type, fold)
)

# Predicted vs. observed values for these folds, with the 1:1 line in red
ggplot(repres_preds, aes(x = att, y = .pred)) +
  geom_point(aes(colour = cv_type)) +
  geom_abline(intercept = 0, slope = 1, colour = "red") +
  coord_fixed() +
  labs(title = "Pred VS truth for a representative fold") +
  facet_grid(cols = vars(cv_type))

Variable importance

Variable importance for each fold of each CV type.

# Unnest variable importance and order the variables by dropout loss so that
# they are sorted on the plot axis
full_vip <- res %>%
  select(cv_type, fold, importance) %>%
  unnest(importance) %>%
  mutate(variable = forcats::fct_reorder(variable, dropout_loss))

# Reference loss of the unpermuted model ("_full_model_" rows), computed
# separately for each CV type and fold.
# An aesthetic such as `xintercept = mean(dropout_loss)` is evaluated once over
# the whole layer data, not per facet, so it would draw the same global mean in
# every panel; summarising beforehand gives each panel its own reference line.
full_model_loss <- full_vip %>%
  filter(variable == "_full_model_") %>%
  group_by(cv_type, fold) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop")

# Variable importance across folds: one panel per fold (rows) and CV type
# (columns), the grey line being the full-model loss of that panel
full_vip %>%
  filter(variable != "_full_model_") %>%
  ggplot() +
  geom_vline(
    data = full_model_loss,
    aes(xintercept = dropout_loss),
    colour = "grey", linewidth = 2
  ) +
  geom_boxplot(aes(x = dropout_loss, y = variable, colour = cv_type)) +
  labs(x = "RMSE after permutations") +
  facet_grid(fold ~ cv_type)

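Now let’s average the dropout loss of each variable within each fold, and look at how these per-fold means are distributed across folds for each CV type.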

# Rows holding the loss of the unpermuted model, used for the reference line
full_model_rows <- filter(full_vip, variable == "_full_model_")

# Mean dropout loss for each CV type, fold and variable
vip_by_fold <- full_vip %>%
  filter(variable != "_full_model_") %>%
  group_by(cv_type, fold, variable) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop")

# Spread of the per-fold means for each variable, coloured by CV type; the grey
# line is the overall mean loss of the full model
ggplot(vip_by_fold) +
  geom_vline(
    data = full_model_rows,
    aes(xintercept = mean(dropout_loss)),
    colour = "grey", linewidth = 2
  ) +
  geom_boxplot(aes(x = dropout_loss, y = variable, colour = cv_type)) +
  labs(x = "Mean RMSE after permutations across CV folds")

There is a lot of variation in variable importance, likely because there are far too many predictors. We will need to refit a model with fewer predictors.

Partial dependence plots

Finally, let’s have a look at partial dependence plots.

PDP are averaged across folds:

  • within each fold, compute the mean of the cp profiles and their spread (standard deviation of the centered cp profiles)

  • perform a weighted average across folds, using 1/var as weights

# Number of top-ranked variables to keep for each CV type
n_pdp <- 3

# Mean dropout loss per CV type and variable, full-model rows excluded
mean_loss <- full_vip %>%
  mutate(variable = as.character(variable)) %>%
  filter(variable != "_full_model_") %>%
  group_by(cv_type, variable) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop")

# Variables for which to plot pdp: the n_pdp variables with the highest mean
# dropout loss within each CV type (the result stays grouped by cv_type)
vars_pdp <- mean_loss %>%
  arrange(desc(dropout_loss)) %>%
  group_by(cv_type) %>%
  slice_head(n = n_pdp)

# Flatten the nested cp profiles, keeping the CV type and fold of each row
cp_profiles <- unnest(select(res, cv_type, fold, cp_profiles), cp_profiles)

## Generate an averaged cp profile across folds for each CV type and variable,
## propagating uncertainties.
## The difficulty is that x values differ between folds; the solution is to
## interpolate yhat on a common set of x values across folds.
## Steps, for each cv_type and each variable:
## 1- compute the mean and spread of cp profiles within each fold
## 2- interpolate yhat value and spread within each fold on a common set of x values
## 3- perform a weighted average of yhat value and spread, using 1/var as weights

# Get names of folds, for later use
folds <- sort(unique(full_vip$fold))

# Apply on each row of vars_pdp, i.e. each (cv_type, variable) pair.
# seq_len() is used instead of 1:nrow() so that nothing is iterated over when
# vars_pdp has no rows.
mean_pdp <- lapply(seq_len(nrow(vars_pdp)), function(r) {

  # Variable and CV type handled in this iteration
  var_name <- vars_pdp$variable[[r]]
  cv_type_name <- vars_pdp$cv_type[[r]]

  ## Step 1: mean and spread of the cp profiles for each fold
  d_pdp <- cp_profiles %>%
    filter(cv_type == cv_type_name & `_vname_` == var_name) %>%
    select(cv_type, fold, `_yhat_`, `_vname_`, `_ids_`, all_of(var_name)) %>%
    arrange(`_ids_`, across(all_of(var_name))) %>%
    # center each cp profile (one profile per fold, variable and id)
    group_by(cv_type, fold, `_vname_`, `_ids_`) %>%
    mutate(yhat_cent = `_yhat_` - mean(`_yhat_`)) %>%
    ungroup() %>%
    # for each fold and value of the variable of interest: mean of the raw
    # profiles and sd of the centered profiles
    group_by(cv_type, fold, across(all_of(var_name))) %>%
    summarise(
      yhat_loc = mean(`_yhat_`),
      yhat_spr = sd(yhat_cent),
      .groups = "drop"
    ) %>%
    # rename by name rather than by position, so that the result does not
    # depend on the column order
    rename(x = all_of(var_name))

  ## Step 2: interpolate yhat values and spread on a common x distribution
  # Common x values: percentiles of all x values pooled across folds
  new_x <- quantile(d_pdp$x, probs = seq(0, 1, 0.01), names = FALSE)

  # Only interpolate folds that actually have profiles for this CV type and
  # variable; a fold with no rows would give the interpolation nothing to use
  fold_names <- intersect(folds, d_pdp$fold)

  # x differs within each fold, so interpolation is performed fold by fold
  int_pdp <- lapply(fold_names, function(fold_name) {
    this_fold <- d_pdp %>% filter(fold == fold_name)

    # Columns are read explicitly from this_fold: tibble() evaluates its
    # arguments sequentially, so a bare `x` would refer to the new `x` column
    tibble(
      cv_type = cv_type_name,
      fold = fold_name,
      var_name = var_name,
      x = new_x,
      yhat_loc = castr::interpolate(
        x = this_fold$x, y = this_fold$yhat_loc, xout = new_x
      ),
      yhat_spr = castr::interpolate(
        x = this_fold$x, y = this_fold$yhat_spr, xout = new_x
      )
    )
  }) %>%
    bind_rows()

  ## Step 3: weighted mean across folds, using 1/var as weights
  # NOTE(review): a spread of exactly 0 gives an infinite weight — confirm this
  # cannot happen for the selected variables
  int_pdp %>%
    group_by(cv_type, var_name, x) %>%
    summarise(
      yhat_loc = wtd.mean(yhat_loc, weights = 1 / yhat_spr^2),
      yhat_spr = wtd.mean(yhat_spr, weights = 1 / yhat_spr^2),
      .groups = "drop"
    ) %>%
    arrange(x)
}) %>%
  bind_rows()

# Arrange in order of most important variables.
# vars_pdp is still grouped by cv_type after slice_head(); the grouping is
# dropped so that fct_inorder() sets the factor levels over the whole table
# and mean_pdp is not left grouped.
mean_pdp <- vars_pdp %>%
  ungroup() %>%
  rename(var_name = variable) %>%
  left_join(mean_pdp, by = join_by(cv_type, var_name)) %>%
  mutate(var_name = fct_inorder(var_name)) %>%
  select(-dropout_loss)

# Averaged partial dependence (line) with a band of +/- one spread around it,
# coloured by CV type, one panel per variable
ggplot(mean_pdp, aes(x = x)) +
  geom_path(aes(y = yhat_loc, colour = cv_type)) +
  geom_ribbon(
    aes(ymin = yhat_loc - yhat_spr, ymax = yhat_loc + yhat_spr, fill = cv_type),
    alpha = 0.2
  ) +
  facet_wrap(vars(var_name), scales = "free_x")

Tendencies of best predictors

Latitude

# Restores saved objects; presumably provides `df`, the full dataset — TODO confirm
load("data/03.all_data.Rdata")

# Variables retained for the partial dependence plots
top_vars <- vars_pdp$variable

# Coordinates, attenuation and the retained predictors
df_trends <- df %>%
  #mutate(log_remain = log(remain)) %>%
  select(lon, lat, att, all_of(top_vars))

# Long format: one row per observation and variable, keyed by latitude, with
# attenuation first and then the predictors in the order of vars_pdp
df_trends_long <- df_trends %>%
  select(-lon) %>%
  pivot_longer(cols = -lat) %>%
  mutate(name = factor(name, levels = c("att", unique(top_vars))))

# Latitudinal trend of each variable; the axes are flipped so that latitude
# runs vertically
ggplot(df_trends_long, aes(x = lat, y = value)) +
  geom_point(size = 0.5) +
  geom_smooth() +
  facet_wrap(vars(name), scales = "free_x", nrow = 1) +
  coord_flip()

Maps

# Distinct variables to map, one row each
vars_to_plot <- df_trends_long %>%
  select(name) %>%
  unique()

# Type of colour scale for each variable, in the order of vars_to_plot:
# "div" uses the diverging palette, anything else the default palette of ggmap()
colour_scales <- c("asc", "asc", "div", "asc", "asc")

# Build one map per variable.
# seq_len() avoids iterating when there is nothing to plot, and identical()
# makes variables beyond the end of colour_scales fall back to the default
# palette instead of failing on an NA condition.
plot_list <- lapply(seq_len(nrow(vars_to_plot)), function(i) {
  my_var <- as.character(vars_to_plot$name[[i]])

  if (identical(colour_scales[i], "div")) {
    ggmap(df_trends, var = my_var, type = "point", palette = div_pal)
  } else {
    ggmap(df_trends, var = my_var, type = "point")
  }
})

plot_list
[[1]]


[[2]]


[[3]]


[[4]]


[[5]]